{
 "cells": [
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import interp\n",
    "from sklearn import metrics\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import dgl\n",
    "import dgl.function as fn"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:57:23.559380300Z",
     "start_time": "2024-05-08T00:57:16.874641Z"
    }
   },
   "id": "64ab8305ebacbbbe",
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Load data"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "cc88fe43a7e086e1"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def load_data(directory, random_seed):\n",
    "    A_DME_D = pd.read_excel(directory + '/A_DME_D.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    A_DME_ME = pd.read_excel(directory + '/A_DME_ME.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    A_DMI_D = pd.read_excel(directory + '/A_DMI_D.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    A_MIME_ME = pd.read_excel(directory + '/A_MIME_ME.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    D_SSM = pd.read_excel(directory + '/disease_Semantic_simi.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    D_GSM = pd.read_excel(directory + '/disease_Gaussian_Simi.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    ME_FSM = pd.read_excel(directory + '/metabolite_func_simi.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    ME_GSM = pd.read_excel(directory + '/metabolite_Gaussian_Simi.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    MI_GSM_1 = pd.read_excel(directory + '/microbe_Gaussian_Simi_1.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    MI_GSM_2 = pd.read_excel(directory + '/microbe_Gaussian_Simi_2.xlsx', header=0, sheet_name='Sheet1').to_numpy()\n",
    "    all_associations = pd.read_excel(directory + '/association_DME.xlsx', header=0, sheet_name='Sheet1', names=['disease', 'metabolite', 'label'])\n",
    "    DMI_associations = pd.read_excel(directory + '/association_DMI.xlsx', header=0, sheet_name='Sheet1', names=['disease', 'microbe', 'label'])\n",
    "    MIME_associations = pd.read_excel(directory + '/association_MIME.xlsx', header=0, sheet_name='Sheet1', names=['microbe', 'metabolite', 'label'])\n",
    "    IMI = (MI_GSM_1 + MI_GSM_2)/2\n",
    "    ID = D_SSM\n",
    "    IME = ME_FSM\n",
    "    for i in range(D_SSM.shape[0]):\n",
    "        for j in range(D_SSM.shape[1]):\n",
    "            if ID[i][j] == 0:\n",
    "                ID[i][j] = D_GSM[i][j]\n",
    "\n",
    "    for i in range(ME_FSM.shape[0]):\n",
    "        for j in range(ME_FSM.shape[1]):\n",
    "            if IME[i][j] == 0:\n",
    "                IME[i][j] = ME_GSM[i][j]\n",
    "    known_associations = all_associations.loc[all_associations['label'] == 1]\n",
    "    unknown_associations = all_associations.loc[all_associations['label'] == 0]\n",
    "    random_negative = unknown_associations.sample(n=known_associations.shape[0], random_state=random_seed, axis=0)\n",
    "\n",
    "    DMI_associations1 = DMI_associations.loc[DMI_associations['label'] == 1]\n",
    "    MIME_associations1 = MIME_associations.loc[MIME_associations['label'] == 1]\n",
    "    sample_df = known_associations.append(random_negative)  # 正负样本1:1采样\n",
    "\n",
    "    # 指针重置\n",
    "    sample_df.reset_index(drop=True, inplace=True)\n",
    "    DMI_associations1.reset_index(drop=True, inplace=True)\n",
    "    MIME_associations1.reset_index(drop=True, inplace=True)\n",
    "    samples = sample_df.values      # 获得重新编号的新样本\n",
    "    DMI_associations = DMI_associations1.values\n",
    "    MIME_associations = MIME_associations1.values\n",
    "    return A_DME_D, A_DME_ME, A_DMI_D, A_MIME_ME, ID, IME, IMI, samples, DMI_associations, MIME_associations"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:57:23.584142600Z",
     "start_time": "2024-05-08T00:57:23.571380200Z"
    }
   },
   "id": "f05e7ecd135684aa",
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Constructing semantic networks under different metapaths"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b247c08196c5032d"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def build_graph(directory, random_seed, device,sample_num):\n",
    "    A_DME_D, A_DME_ME, A_DMI_D, A_MIME_ME, ID, IME, IMI, samples, DMI_associations, MIME_associations = load_data(directory, random_seed)\n",
    "    # 构造D-ME-D元路径的图\n",
    "    g_D0 = dgl.DGLGraph().to(device)\n",
    "    g_D0.add_nodes(A_DME_D.shape[0])\n",
    "    g_D0 = dgl.add_self_loop(g_D0).to(device)\n",
    "    rows, cols = np.where(A_DME_D == 1)\n",
    "    g_D0.add_edges(rows, cols)\n",
    "    d_sim = torch.zeros(g_D0.number_of_nodes(), ID.shape[1]).to(device)\n",
    "    d_sim[:, :] = torch.from_numpy(ID.astype('float32')).to(device)\n",
    "    g_D0.ndata['d_sim'] = d_sim.to(device)\n",
    "\n",
    "    # 构造ME-D-ME元路径的图\n",
    "    New_adj1 = sage_sample(A_DME_ME, IME,sample_num)\n",
    "    g_ME0 = dgl.DGLGraph().to(device)\n",
    "    g_ME0.add_nodes(New_adj1.shape[1])\n",
    "    g_ME0 = dgl.add_self_loop(g_ME0).to(device)\n",
    "    rows, cols = np.where(New_adj1 == 1)\n",
    "    g_ME0.add_edges(rows, cols)\n",
    "    me_sim = torch.zeros(g_ME0.number_of_nodes(), IME.shape[1]).to(device)\n",
    "    me_sim[:, :] = torch.from_numpy(IME.astype('float32')).to(device)\n",
    "    g_ME0.ndata['me_sim'] = me_sim.to(device)\n",
    "\n",
    "    # 构造D-MI-D元路径的图\n",
    "    g_D1 = dgl.DGLGraph().to(device)\n",
    "    g_D1.add_nodes(A_DMI_D.shape[0])\n",
    "    g_D1 = dgl.add_self_loop(g_D1).to(device)\n",
    "    rows, cols = np.where(A_DMI_D == 1)\n",
    "    g_D1.add_edges(rows, cols)\n",
    "    d_sim = torch.zeros(g_D1.number_of_nodes(), ID.shape[1]).to(device)\n",
    "    d_sim[:, :] = torch.from_numpy(ID.astype('float32')).to(device)\n",
    "    g_D1.ndata['d_sim'] = d_sim.to(device)\n",
    "\n",
    "    # 构造ME-MI-ME元路径的图\n",
    "    New_adj2 = sage_sample(A_MIME_ME, IME, sample_num)\n",
    "    g_ME1 = dgl.DGLGraph().to(device)\n",
    "    g_ME1.add_nodes(New_adj2.shape[1])\n",
    "    g_ME1 = dgl.add_self_loop(g_ME1).to(device)\n",
    "    rows, cols = np.where(New_adj2 == 1)\n",
    "    g_ME1.add_edges(rows, cols)\n",
    "\n",
    "    me_sim = torch.zeros(g_ME1.number_of_nodes(), IME.shape[1]).to(device)\n",
    "    me_sim[:, :] = torch.from_numpy(IME.astype('float32')).to(device)\n",
    "    g_ME1.ndata['me_sim'] = me_sim.to(device)\n",
    "\n",
    "    g = dgl.DGLGraph().to(device)\n",
    "    g.add_nodes(ID.shape[0] + IME.shape[0])\n",
    "    node_type = torch.zeros(g.number_of_nodes(), dtype=torch.int64).to(device)\n",
    "    node_type[: ID.shape[0]] = 1\n",
    "    g.ndata['type'] = node_type.to(device)\n",
    "\n",
    "    d_sim = torch.zeros(g.number_of_nodes(), ID.shape[1]).to(device)\n",
    "    d_sim[: ID.shape[0], :] = torch.from_numpy(ID.astype('float32')).to(device)\n",
    "    g.ndata['d_sim'] = d_sim.to(device)\n",
    "\n",
    "    me_sim = torch.zeros(g.number_of_nodes(), IME.shape[1]).to(device)\n",
    "    me_sim[ID.shape[0]: ID.shape[0]+IME.shape[0], :] = torch.from_numpy(IME.astype('float32')).to(device)\n",
    "    g.ndata['me_sim'] = me_sim.to(device)\n",
    "\n",
    "    disease_ids = list(range(1, ID.shape[0]+1))\n",
    "    metabolite_ids = list(range(1, IME.shape[0]+1))\n",
    "\n",
    "    disease_ids_invmap = {id_: i for i, id_ in enumerate(disease_ids)}\n",
    "    metabolite_ids_invmap = {id_: i for i, id_ in enumerate(metabolite_ids)}\n",
    "\n",
    "    sample_disease_vertices = [disease_ids_invmap[id_] for id_ in samples[:, 0]]\n",
    "    sample_metabolite_vertices = [metabolite_ids_invmap[id_] + ID.shape[0] for id_ in samples[:, 1]]\n",
    "\n",
    "    g.add_edges(sample_disease_vertices, sample_metabolite_vertices,\n",
    "                data={'label': torch.from_numpy(samples[:, 2].astype('float32'))})\n",
    "    g.add_edges(sample_metabolite_vertices, sample_disease_vertices,   # 添加双向边（无向）\n",
    "                data={'label': torch.from_numpy(samples[:, 2].astype('float32'))})\n",
    "\n",
    "    g0 = dgl.DGLGraph().to(device)\n",
    "    g0.add_nodes(ID.shape[0] + IME.shape[0] + IMI.shape[0])\n",
    "    node_type = torch.zeros(g0.number_of_nodes(), dtype=torch.int64).to(device)\n",
    "    node_type[: ID.shape[0]] = 1\n",
    "    node_type[ID.shape[0] + IME.shape[0]:] = 2\n",
    "    g0.ndata['type'] = node_type.to(device)\n",
    "\n",
    "    d_sim = torch.zeros(g0.number_of_nodes(), ID.shape[1]).to(device)\n",
    "    d_sim[: ID.shape[0], :] = torch.from_numpy(ID.astype('float32')).to(device)\n",
    "    g0.ndata['d_sim'] = d_sim.to(device)\n",
    "\n",
    "    me_sim = torch.zeros(g0.number_of_nodes(), IME.shape[1]).to(device)\n",
    "    me_sim[ID.shape[0]: ID.shape[0]+IME.shape[0], :] = torch.from_numpy(IME.astype('float32')).to(device)\n",
    "    g0.ndata['me_sim'] = me_sim.to(device)\n",
    "\n",
    "    mi_sim = torch.zeros(g0.number_of_nodes(), IMI.shape[1]).to(device)\n",
    "    mi_sim[ID.shape[0]+IME.shape[0]: ID.shape[0]+IME.shape[0]+IMI.shape[0], :] = torch.from_numpy(IMI.astype('float32')).to(device)\n",
    "    g0.ndata['mi_sim'] = mi_sim.to(device)\n",
    "\n",
    "    microbe_ids = list(range(1, IMI.shape[0]+1))\n",
    "    microbe_ids_invmap = {id_: i for i, id_ in enumerate(microbe_ids)}\n",
    "\n",
    "    dmi_disease_vertices = [disease_ids_invmap[id_] for id_ in DMI_associations[:, 0]]\n",
    "    dmi_microbe_vertices = [microbe_ids_invmap[id_] + ID.shape[0] + IME.shape[0] for id_ in DMI_associations[:, 1]]\n",
    "    mime_microbe_vertices = [microbe_ids_invmap[id_] + ID.shape[0] + IME.shape[0] for id_ in MIME_associations[:, 0]]\n",
    "    mime_metabolite_vertices = [metabolite_ids_invmap[id_] + ID.shape[0] for id_ in MIME_associations[:, 1]]\n",
    "\n",
    "    g0.add_edges(sample_disease_vertices, sample_metabolite_vertices,\n",
    "                data={'dme': torch.from_numpy(samples[:, 2].astype('float32'))})\n",
    "    g0.add_edges(sample_metabolite_vertices, sample_disease_vertices,\n",
    "                data={'med': torch.from_numpy(samples[:, 2].astype('float32'))})\n",
    "    g0.add_edges(dmi_disease_vertices, dmi_microbe_vertices,\n",
    "                data={'dmi': torch.from_numpy(DMI_associations[:, 2].astype('float32'))})\n",
    "    g0.add_edges(dmi_microbe_vertices, dmi_disease_vertices,\n",
    "                data={'mid': torch.from_numpy(DMI_associations[:, 2].astype('float32'))})\n",
    "    g0.add_edges(mime_microbe_vertices, mime_metabolite_vertices,\n",
    "                data={'mime': torch.from_numpy(MIME_associations[:, 2].astype('float32'))})\n",
    "    g0.add_edges(mime_metabolite_vertices, mime_microbe_vertices,\n",
    "                data={'memi': torch.from_numpy(MIME_associations[:, 2].astype('float32'))})\n",
    "    return g_D0, g_ME0, g_D1, g_ME1, g, g0, sample_disease_vertices, sample_metabolite_vertices, ID, IME, IMI, samples, DMI_associations, MIME_associations"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:57:23.637137100Z",
     "start_time": "2024-05-08T00:57:23.602133500Z"
    }
   },
   "id": "1fe63c66b0043393",
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "source": [
    "# The Sampling Strategy of the enhanced GraphSAGE Model"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "bc006d9030b227a1"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def sage_sample(Adj, Fea, num_neighbors):\n",
    "    num_nodes = Adj.shape[0]\n",
    "    weights_matrix = np.zeros((num_nodes, num_nodes))\n",
    "    for i in range(num_nodes):\n",
    "        for j in range(num_nodes):\n",
    "            if Adj[i][j] == 1:\n",
    "                weight = np.dot(Fea[i], Fea[j])\n",
    "                weights_matrix[i][j] = weight\n",
    "    top_k_neighbors = num_neighbors\n",
    "    num_nodes = weights_matrix.shape[0]\n",
    "    new_adj_matrix = np.zeros_like(weights_matrix)\n",
    "    for i in range(num_nodes):\n",
    "        weights_vector = weights_matrix[i]\n",
    "        nonzero_indices = np.where(weights_vector != 0)[0]\n",
    "        nonzero_weights = weights_vector[nonzero_indices]\n",
    "        sorted_indices = np.argsort(nonzero_weights)[::-1]  # 降序排序\n",
    "        selected_indices = nonzero_indices[sorted_indices[:top_k_neighbors]]\n",
    "        new_adj_matrix[i, selected_indices] = 1\n",
    "    return new_adj_matrix"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:57:23.638129600Z",
     "start_time": "2024-05-08T00:57:23.629150800Z"
    }
   },
   "id": "12c86a89dea11b75",
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Reset weight parameters"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "345a8b884613d34"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def weight_reset(m):\n",
    "    if isinstance(m, nn.Linear):\n",
    "        m.reset_parameters()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:57:23.664141500Z",
     "start_time": "2024-05-08T00:57:23.642131900Z"
    }
   },
   "id": "ad180a098b5dccab",
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "source": [
    "# ROC curve"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ca98895c2666abc"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def plot_auc_curves(fprs, tprs, auc, directory, name):\n",
    "    mean_fpr = np.linspace(0, 1, 20000)\n",
    "    tpr = []\n",
    "\n",
    "    for i in range(len(fprs)):\n",
    "        tpr.append(interp(mean_fpr, fprs[i], tprs[i]))\n",
    "        tpr[-1][0] = 0.0\n",
    "        plt.plot(fprs[i], tprs[i], alpha=0.4, linestyle='--', label='Fold %d AUC: %.4f' % (i + 1, auc[i]))\n",
    "\n",
    "    mean_tpr = np.mean(tpr, axis=0)\n",
    "    mean_tpr[-1] = 1.0\n",
    "    # mean_auc = metrics.auc(mean_fpr, mean_tpr)\n",
    "    mean_auc = np.mean(auc)\n",
    "    auc_std = np.std(auc)\n",
    "    plt.plot(mean_fpr, mean_tpr, color='BlueViolet', alpha=0.9, label='Mean AUC: %.4f' % mean_auc)\n",
    "    plt.plot([0, 1], [0, 1], linestyle='--', color='black', alpha=0.4)\n",
    "    plt.xlim([-0.05, 1.05])\n",
    "    plt.ylim([-0.05, 1.05])\n",
    "    plt.xlabel('False Positive Rate')\n",
    "    plt.ylabel('True Positive Rate')\n",
    "    plt.title('ROC curve')\n",
    "    plt.legend(loc='lower right')\n",
    "    plt.savefig(directory+'/%s.pdf' % name, dpi=300, bbox_inches='tight')\n",
    "    plt.close()\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:57:23.677130300Z",
     "start_time": "2024-05-08T00:57:23.662131600Z"
    }
   },
   "id": "409117a7b7bc7665",
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "source": [
    "# PR curve"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "54362cd610a0ffc1"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def plot_prc_curves(precisions, recalls, prc, directory, name):\n",
    "    mean_recall = np.linspace(0, 1, 20000)\n",
    "    precision = []\n",
    "    for i in range(len(recalls)):\n",
    "        precision.append(interp(1-mean_recall, 1-recalls[i], precisions[i]))\n",
    "        precision[-1][0] = 1.0\n",
    "        plt.plot(recalls[i], precisions[i], alpha=0.4, linestyle='--', label='Fold %d AUPR: %.4f' % (i + 1, prc[i]))\n",
    "    mean_precision = np.mean(precision, axis=0)\n",
    "    mean_precision[-1] = 0\n",
    "    # mean_prc = metrics.auc(mean_recall, mean_precision)\n",
    "    mean_prc = np.mean(prc)\n",
    "    prc_std = np.std(prc)\n",
    "    plt.plot(mean_recall, mean_precision, color='BlueViolet', alpha=0.9,\n",
    "             label='Mean AUPR: %.4f' % mean_prc)  # AP: Average Precision\n",
    "    plt.plot([1, 0], [0, 1], linestyle='--', color='black', alpha=0.4)\n",
    "    plt.xlim([-0.05, 1.05])\n",
    "    plt.ylim([-0.05, 1.05])\n",
    "    plt.xlabel('Recall')\n",
    "    plt.ylabel('Precision')\n",
    "    plt.title('PR curve')\n",
    "    plt.legend(loc='lower left')\n",
    "    plt.savefig(directory + '/%s.pdf' % name, dpi=300, bbox_inches='tight')\n",
    "    plt.close()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:57:23.692129400Z",
     "start_time": "2024-05-08T00:57:23.675129900Z"
    }
   },
   "id": "61f4f0762edd5c02",
   "execution_count": 7
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
